package com.rpc.learn.transport;

import com.rpc.learn.regist.ServiceRegister;
import com.rpc.learn.transport.dto.RpcRequest;
import com.rpc.learn.transport.dto.RpcResponse;
import com.rpc.learn.transport.handler.CustomObjectDecoder;
import com.rpc.learn.transport.handler.CustomObjectEncoder;
import com.rpc.learn.transport.handler.NettyRpcClientHandler;
import com.rpc.learn.transport.support.ChannelProvider;
import com.rpc.learn.transport.support.UnprocessedRequests;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.net.InetSocketAddress;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

@Slf4j
/**
 * Netty-based RPC client that sends {@link RpcRequest}s to a remote address and hands back a
 * future for the reply.
 *
 * <p>A single shared instance is obtained through {@link #getNettyRpcClient()}. Connections are
 * cached per address in a {@link ChannelProvider}. Every outgoing request's future is registered in
 * {@link UnprocessedRequests} under its request id. The same registry is given to
 * {@link NettyRpcClientHandler}, which presumably completes the future when the response arrives
 * (not visible here — confirm).
 */
public class NettyRpcClient {

    private final UnprocessedRequests unprocessedRequests = new UnprocessedRequests();
    private final ChannelProvider channelProvider = new ChannelProvider();
    private final Bootstrap bootstrap;
    private final EventLoopGroup eventLoopGroup;

    /** Lazily created shared instance; guarded by the class lock in {@link #getNettyRpcClient()}. */
    private static NettyRpcClient nettyRpcClient;

    /** Creates the event loop group and a Bootstrap configured with the client pipeline. */
    private NettyRpcClient() {
        eventLoopGroup = new NioEventLoopGroup();
        bootstrap = new Bootstrap();
        // Bootstrap.handler() keeps only the LAST handler set. A LoggingHandler configured before the
        // ChannelInitializer was therefore silently discarded, so it is no longer set here. Add it
        // inside initChannel() if wire-level logging is wanted.
        bootstrap.group(eventLoopGroup)
                .channel(NioSocketChannel.class)
                // Connection attempts that do not succeed within 5000 ms fail.
                .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) {
                        ChannelPipeline p = ch.pipeline();
                        // Heartbeat (currently disabled): IdleStateHandler(0, 5, 0) would raise a
                        // writer-idle event after 5 seconds without outbound writes.
//                        p.addLast(new IdleStateHandler(0, 5, 0, TimeUnit.SECONDS));

                        // Netty's built-in serialization codecs (currently disabled; custom codecs are used below).
//                        p.addLast(new ObjectEncoder());
//                        p.addLast(new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.cacheDisabled(null)));

                        p.addLast(new CustomObjectEncoder());
                        p.addLast(new CustomObjectDecoder());
                        p.addLast(new NettyRpcClientHandler(unprocessedRequests));
                    }
                });
    }

    /**
     * Returns the shared client, creating it on first use.
     *
     * <p>Synchronized so that concurrent first calls cannot create several clients, each owning its
     * own event loop group.
     *
     * @return the process-wide {@code NettyRpcClient}
     */
    public static synchronized NettyRpcClient getNettyRpcClient() {
        if (nettyRpcClient == null) {
            nettyRpcClient = new NettyRpcClient();
        }
        return nettyRpcClient;
    }

    /**
     * Sends an RPC request for {@code serviceName} to {@code address}.
     *
     * <p>The parameter types sent to the server are the runtime classes of {@code params}.
     *
     * @param address     remote server address; a connection is opened if none is cached
     * @param serviceName name of the remote service to invoke
     * @param params      call arguments; may be {@code null} or empty (then no parameter types are sent)
     * @return a future that fails immediately if the write fails. Otherwise it is completed with
     *     the reply via {@code unprocessedRequests}. Internally it is a
     *     {@code CompletableFuture<RpcResponse>}; the raw return type is kept for source
     *     compatibility with existing callers.
     * @throws NullPointerException  if any element of {@code params} is {@code null}
     * @throws IllegalStateException if a connection to {@code address} cannot be established
     */
    public CompletableFuture sendRequest(InetSocketAddress address, String serviceName, Object[] params) {
        Class<?>[] paramTypes = null;
        if (params != null && params.length > 0) {
            paramTypes = new Class<?>[params.length];
            for (int i = 0; i < params.length; i++) {
                paramTypes[i] = params[i].getClass();
            }
        }

        CompletableFuture<RpcResponse> resultFuture = new CompletableFuture<>();
        // Obtain (or open) the connection to the target server.
        Channel channel = getChannel(address);
        RpcRequest rpcRequest = RpcRequest.builder()
                .requestId(UUID.randomUUID().toString())
                .parameters(params)
                .serviceName(serviceName)
                .paramTypes(paramTypes)
                .build();
        // Register before writing so a fast response can always find its future.
        unprocessedRequests.put(rpcRequest.getRequestId(), resultFuture);
        channel.writeAndFlush(rpcRequest).addListener((ChannelFutureListener) future -> {
            if (future.isSuccess()) {
                log.info("client send message: [{}]", rpcRequest);
            } else {
                // NOTE(review): the entry registered in unprocessedRequests is not removed on send
                // failure; confirm whether UnprocessedRequests offers a removal API to avoid a leak.
                future.channel().close();
                resultFuture.completeExceptionally(future.cause());
                log.error("Send failed:", future.cause());
            }
        });
        return resultFuture;
    }

    /**
     * Returns a usable cached channel for {@code address}, connecting if there is none.
     *
     * <p>The cache is checked once without the lock and again under it. This means only one thread
     * performs the connect when several miss at the same time.
     *
     * @throws IllegalStateException if the connection attempt fails or the thread is interrupted
     */
    private Channel getChannel(InetSocketAddress address) {
        Channel channel = channelProvider.get(address);
        if (isUsable(channel)) {
            return channel;
        }
        synchronized (this) {
            channel = channelProvider.get(address);
            if (!isUsable(channel)) {
                channel = connect(address);
                channelProvider.put(address, channel);
            }
        }
        return channel;
    }

    /** Returns whether a cached channel exists and is still active (not closed). */
    private static boolean isUsable(Channel channel) {
        return channel != null && channel.isActive();
    }

    /**
     * Opens a new connection to {@code address} and blocks until it succeeds or fails.
     *
     * <p>The connect listener must complete the future on BOTH outcomes. Throwing from the listener
     * (as before) is swallowed by Netty and left the caller blocked in {@code get()} forever.
     *
     * @throws IllegalStateException if the connect fails (cause preserved) or the thread is interrupted
     */
    private Channel connect(InetSocketAddress address) {
        CompletableFuture<Channel> connected = new CompletableFuture<>();
        bootstrap.connect(address).addListener((ChannelFutureListener) future -> {
            if (future.isSuccess()) {
                log.info("The client has connected [{}] successful!", address.toString());
                connected.complete(future.channel());
            } else {
                connected.completeExceptionally(future.cause());
            }
        });
        try {
            return connected.get();
        } catch (ExecutionException e) {
            throw new IllegalStateException("Failed to connect to " + address, e.getCause());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while connecting to " + address, e);
        }
    }
}
